{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "(https://arxiv.org/pdf/2104.09864) [Roformer: Enhanced Transformer With Rotary\n",
    "Position Embedding]\n",
    "\n",
    "I learned how RoPE works from https://huggingface.co/blog/designing-positional-encoding\n",
    "the article motivates how it was developed particularly well. I encourage you read it to get intuition for \n",
    "things like how, where and why there's rotation happening, and the key desiderata for positional embeddings that lead naturally \n",
    "first to sinusoidal embeddings, and then to RoPE after that. \n",
    "    A crucial observation motivating positional embeddings is that attentions is a *set* operation\n",
    "    ie. does not use any positional information, just pairwise comparison between tokens \n",
    "    independent of their position, so if you don't put positional information in the embeddings, \n",
    "    then the model that *no way of telling apart two identical words at different places in the sequence \n",
    "    ie, no way to exploiting context to infer semantic meaning* and so you will *cripple the model*. \n",
    "\n",
    "\n",
    "We start by implementing sinusoidal embeddings from \"Attention is all you need\" as a baseline\n",
    "these are applied to [batch_size, hidden_dim] matrices representing tokens after \n",
    "they are embedded from tokens to hidden_dim in the Transformer embedding layer \n",
    "\n",
    "In contrast, RoPE adds to the (q,k) attention matrices directly, since only those affect inter-token \n",
    "computation and thus use positional information across tokens. \n",
    "'''\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 64, 768])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch \n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F\n",
    "\n",
    "b, s, d = 4, 64, 768 \n",
    "embeddings = torch.randn(b, s, d)\n",
    "\n",
    "# sinusoidal embeddings, like more pos_embeds, operates on the token embeddings and not attn (like rope)\n",
    "def get_sinusoidal_embeddings(s, d): \n",
    "    pos_embeddings = torch.zeros(s, d)\n",
    "    base = torch.tensor(10_000)\n",
    "    div_even = torch.exp(torch.log(base) * 2 * torch.arange(0, d//2)/d)\n",
    "    div_odd = torch.exp(torch.log(base) * 2 * torch.arange(1, d//2+1)/d)\n",
    "    pos_embeddings[:, 0::2] = (torch.arange(s)[:,None] / div_even) # want [s, d//2]\n",
    "    pos_embeddings[:, 1::2] = (torch.arange(s)[:,None] / div_odd)\n",
    "    return pos_embeddings # [s, d]\n",
    "\n",
    "# do embeddings + get_sinusoidal_embeddings(s, d) and it'll broadcast over batch\n",
    "# RoPE, unlike traditional positional embeddings, operates on q, k in attention directly\n",
    "def add_rope_embeddings(q, k):\n",
    "    b, s, d = q.shape\n",
    "\n",
    "    thetas = 1.0 / (10000 ** (torch.arange(0, d // 2) / d // 2))\n",
    "\n",
    "    # Compute sin and cos for each position\n",
    "    positions = torch.arange(s)\n",
    "    freqs = positions[:, None] * thetas # [s, d // 2], uses broadcasting to coming [s] * [d//2] outer product \n",
    "    sin, cos = freqs.sin(), freqs.cos() # [s, d // 2]\n",
    "\n",
    "    sin = sin.repeat_interleave(2, dim=-1).unsqueeze(0) # [1, s, d]\n",
    "    cos = cos.repeat_interleave(2, dim=-1).unsqueeze(0) # [1, s, d]\n",
    "\n",
    "    # rotate every pair of features (x1, x2) -> (-x2, x1)\n",
    "    def _rotate_every_two(x: torch.Tensor) -> torch.Tensor:\n",
    "        x = x.reshape(b, s, d // 2, 2)\n",
    "        x1, x2 = x.unbind(-1)\n",
    "        return torch.stack((-x2, x1), dim=-1).reshape(b, s, d)\n",
    "\n",
    "    q_rot = (q * cos) + (_rotate_every_two(q) * sin)\n",
    "    k_rot = (k * cos) + (_rotate_every_two(k) * sin)\n",
    "    return q_rot, k_rot\n",
    "\n",
    "\n",
    "q, k = torch.randn(b, s, d), torch.randn(b, s, d)\n",
    "add_rope_embeddings(q, k)[0].shape  # [b, s, d]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "envi",
   "language": "python",
   "name": "lingua_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
